package com.bayesianWeka;

import weka.core.*;

import weka.classifiers.AbstractClassifier;

import java.text.DecimalFormat;
import java.util.Arrays;

/**
 * Attribute- and instance-weighted naive Bayes classifier for nominal data.
 *
 * <p>Training derives two kinds of weights:
 * <ul>
 * <li>an <em>instance weight</em> for every training instance: the sum, over the non-class
 * attributes, of the relative frequency of the instance's value times the attribute's number of
 * values. Class and class-conditional counts are accumulated with these weights;</li>
 * <li>an <em>attribute weight</em> for every non-class attribute: a sigmoid of the difference
 * between its normalised mutual information with the class and its average normalised mutual
 * information with the other non-class attributes.</li>
 * </ul>
 * At prediction time each smoothed class-conditional probability is raised to the power of its
 * attribute weight.
 *
 * <p>Attribute values are used directly as nominal value indices ({@code (int) value}), so only
 * nominal attributes and a nominal class are meaningful. Training instances with a missing class
 * are discarded, and missing attribute values contribute to no count and to no product term.
 */
public class AIWNB_E extends AbstractClassifier {

	/** Number of class values. */
	private int m_NumClasses;

	/** Number of training instances (after discarding those with a missing class). */
	private int m_NumInstances;

	/** Number of attributes, including the class attribute. */
	private int m_NumAttributes;

	/** Number of values of each attribute, indexed by attribute. */
	private int[] m_NumAiValue;

	/** Relative frequency of each attribute value: [attribute][value]. */
	private double[][] m_AttributesDensity;

	/** Raw count of each attribute value: [attribute][value]. */
	private int[][] m_CountAttributes;

	/** Index of the class attribute. */
	private int m_ClassIndex;

	/** Instance-weighted count of each class value. */
	private double[] m_CountClass;

	/** Instance-weighted count of each (class, attribute, value): [class][attribute][value]. */
	private double[][][] m_CountClassAi;

	/** for serialization */
//	static final long serialVersionUID = -4503874444306113214L;

	/** Number of times each class value occurs in the training data. */
	private double[] m_ClassCounts;

	/** Number of times each attribute value occurs, indexed by flattened value index. */
	private double[] m_AttCounts;

	/** Number of times each pair of attribute values co-occurs (flattened value indices). */
	private double[][] m_AttAttCounts;

	/** Number of times each (class, value, value) triple occurs (flattened value indices). */
	private double[][][] m_ClassAttAttCounts;

	/** Number of times each (class, value) pair occurs (flattened value index). */
	private double[][] m_ClassAttCounts;

	/** Number of values of each attribute, indexed by attribute. */
	private int[] m_NumAttValues;

	/** Total number of values over all attributes (the class attribute included). */
	private int m_TotalAttValues;

	/** Offset of each attribute's first value in the flattened value index. */
	private int[] m_StartAttIndex;

	/**
	 * Mutual information between each pair of attributes (class included). After training,
	 * entries [a][class] and [a][b] for non-class a != b are normalised; [class][a] keeps the raw
	 * value.
	 */
	private double[][] m_mutualInformation;

	/** Exponent applied to each attribute's conditional probability; the class slot stays 0. */
	private double[] Attr_Weight;

	/** Weight of each training instance. */
	private double[] Ins_Weight;

	/**
	 * Conditional mutual information I(Y; Xi | Xj) for every attribute pair, filled by
	 * {@link #getConditionMutualInfor_yxx()} during training. Not used for prediction.
	 */
	private double[][] m_conditionalMutualInformationYxx;

	/**
	 * Returns the largest element of {@code arr}.
	 *
	 * @param arr a non-empty array
	 * @return the maximum value
	 */
	public int getMaxValue(int[] arr)
	{
		int max = arr[0];
		for (int j = 1; j < arr.length; j++) {
			if (max < arr[j])
				max = arr[j];
		}
		return max;
	}

	/**
	 * Builds the classifier from nominal training data.
	 *
	 * @param instances the training data; it is copied, never modified
	 * @throws Exception if the data cannot be processed (e.g. no class attribute is set)
	 */
	public void buildClassifier(Instances instances) throws Exception {
		// Work on a copy and drop instances with a missing class: (int) of a missing (NaN) class
		// value is 0, which would silently count them as the first class.
		Instances data = new Instances(instances);
		data.deleteWithMissingClass();

		m_ClassIndex = data.classIndex();
		m_NumAttributes = data.numAttributes();
		m_NumInstances = data.numInstances();
		m_NumClasses = data.numClasses();

		m_NumAiValue = new int[m_NumAttributes];
		for (int i = 0; i < m_NumAttributes; i++)
			m_NumAiValue[i] = data.attribute(i).numValues();

		computeInstanceWeights(data);
		accumulateWeightedClassCounts(data);
		accumulateFrequencyCounts(data);
		computeMutualInformation();
		m_conditionalMutualInformationYxx = getConditionMutualInfor_yxx();
		computeAttributeWeights();
	}

	/** Computes attribute-value relative frequencies and, from them, each instance's weight. */
	private void computeInstanceWeights(Instances data) {
		int maxNumValues = getMaxValue(m_NumAiValue);
		m_CountAttributes = new int[m_NumAttributes][maxNumValues];
		m_AttributesDensity = new double[m_NumAttributes][maxNumValues];

		for (int i = 0; i < m_NumInstances; i++) {
			Instance inst = data.instance(i);
			for (int j = 0; j < m_NumAttributes; j++) {
				if (!inst.isMissing(j))
					m_CountAttributes[j][(int) inst.value(j)]++;
			}
		}

		// Plain relative frequency (no Laplace smoothing); nothing to do for empty data.
		if (m_NumInstances > 0) {
			for (int i = 0; i < m_NumAttributes; i++) {
				for (int j = 0; j < m_NumAiValue[i]; j++) {
					m_AttributesDensity[i][j] = (double) m_CountAttributes[i][j] / m_NumInstances;
				}
			}
		}

		Ins_Weight = new double[m_NumInstances];
		for (int i = 0; i < m_NumInstances; i++) {
			Instance inst = data.instance(i);
			for (int j = 0; j < m_NumAttributes; j++) {
				if (j != m_ClassIndex && !inst.isMissing(j)) {
					Ins_Weight[i] += m_AttributesDensity[j][(int) inst.value(j)] * m_NumAiValue[j];
				}
			}
		}
	}

	/** Accumulates instance-weighted class and (class, attribute, value) counts. */
	private void accumulateWeightedClassCounts(Instances data) {
		m_CountClass = new double[m_NumClasses];
		// Indexed by the attribute's own index, so the attribute dimension must cover every
		// attribute (the class slot stays unused). Sizing it m_NumAttributes - 1 overflows
		// whenever the class is not the last attribute.
		m_CountClassAi = new double[m_NumClasses][m_NumAttributes][getMaxValue(m_NumAiValue)];

		for (int i = 0; i < m_NumInstances; i++) {
			Instance inst = data.instance(i);
			int classValue = (int) inst.classValue();
			m_CountClass[classValue] += Ins_Weight[i];
			for (int j = 0; j < m_NumAttributes; j++) {
				if (j != m_ClassIndex && !inst.isMissing(j)) {
					m_CountClassAi[classValue][j][(int) inst.value(j)] += Ins_Weight[i];
				}
			}
		}
	}

	/** Builds the flattened value index and the unweighted single/pair/class frequency counts. */
	private void accumulateFrequencyCounts(Instances data) {
		m_StartAttIndex = new int[m_NumAttributes];
		m_NumAttValues = new int[m_NumAttributes];
		// Reset: otherwise a second buildClassifier call keeps adding to the previous total.
		m_TotalAttValues = 0;
		for (int i = 0; i < m_NumAttributes; i++) {
			m_StartAttIndex[i] = m_TotalAttValues;
			m_NumAttValues[i] = data.attribute(i).numValues();
			m_TotalAttValues += m_NumAttValues[i];
		}

		m_ClassCounts = new double[m_NumClasses];
		m_AttCounts = new double[m_TotalAttValues];
		m_AttAttCounts = new double[m_TotalAttValues][m_TotalAttValues];
		m_ClassAttCounts = new double[m_NumClasses][m_TotalAttValues];
		m_ClassAttAttCounts = new double[m_NumClasses][m_TotalAttValues][m_TotalAttValues];

		int[] attIndex = new int[m_NumAttributes];
		for (int k = 0; k < m_NumInstances; k++) {
			Instance inst = data.instance(k);
			int classVal = (int) inst.classValue();
			m_ClassCounts[classVal]++;
			for (int i = 0; i < m_NumAttributes; i++) {
				// -1 marks a missing value, which contributes to no count.
				attIndex[i] = inst.isMissing(i) ? -1 : m_StartAttIndex[i] + (int) inst.value(i);
				if (attIndex[i] >= 0)
					m_AttCounts[attIndex[i]]++;
			}
			for (int att1 = 0; att1 < m_NumAttributes; att1++) {
				if (attIndex[att1] < 0)
					continue;
				for (int att2 = 0; att2 < m_NumAttributes; att2++) {
					if (attIndex[att2] < 0)
						continue;
					m_AttAttCounts[attIndex[att1]][attIndex[att2]]++;
					m_ClassAttAttCounts[classVal][attIndex[att1]][attIndex[att2]]++;
				}
				m_ClassAttCounts[classVal][attIndex[att1]]++;
			}
		}
	}

	/** Fills the symmetric matrix of pairwise mutual information (class attribute included). */
	private void computeMutualInformation() {
		m_mutualInformation = new double[m_NumAttributes][m_NumAttributes];
		for (int att1 = 0; att1 < m_NumAttributes; att1++) {
			for (int att2 = att1 + 1; att2 < m_NumAttributes; att2++) {
				m_mutualInformation[att1][att2] = mutualInfo(att1, att2);
				m_mutualInformation[att2][att1] = m_mutualInformation[att1][att2];
			}
		}
	}

	/**
	 * Normalises the mutual information values and derives each non-class attribute's weight as
	 * sigmoid(normalised MI with class - average normalised MI with the other attributes).
	 */
	private void computeAttributeWeights() {
		int numPredictors = m_NumAttributes - 1;

		// Scale each attribute-class MI by their average. Skipped when the average is not
		// positive (no predictors, or no attribute informative about the class): dividing by it
		// would make every weight NaN.
		double ave = 0;
		for (int att = 0; att < m_NumAttributes; att++) {
			if (att != m_ClassIndex)
				ave += m_mutualInformation[att][m_ClassIndex];
		}
		if (numPredictors > 0)
			ave /= numPredictors;
		if (ave > 0) {
			for (int att = 0; att < m_NumAttributes; att++) {
				if (att != m_ClassIndex)
					m_mutualInformation[att][m_ClassIndex] /= ave;
			}
		}

		// Scale each attribute-attribute MI by the average over ordered pairs of distinct
		// non-class attributes; with fewer than two predictors there are no such pairs.
		int numPairs = numPredictors * (numPredictors - 1);
		double mean = 0;
		for (int att1 = 0; att1 < m_NumAttributes; att1++) {
			if (att1 == m_ClassIndex)
				continue;
			for (int att2 = 0; att2 < m_NumAttributes; att2++) {
				if (att2 == m_ClassIndex || att2 == att1)
					continue;
				mean += m_mutualInformation[att1][att2];
			}
		}
		if (numPairs > 0)
			mean /= numPairs;
		if (mean > 0) {
			for (int att1 = 0; att1 < m_NumAttributes; att1++) {
				if (att1 == m_ClassIndex)
					continue;
				for (int att2 = 0; att2 < m_NumAttributes; att2++) {
					if (att2 == m_ClassIndex || att2 == att1)
						continue;
					m_mutualInformation[att1][att2] /= mean;
				}
			}
		}

		// Average (normalised) MI of each attribute with the other non-class attributes.
		double[] aveMutualInfo = new double[m_NumAttributes];
		for (int att1 = 0; att1 < m_NumAttributes; att1++) {
			if (att1 == m_ClassIndex)
				continue;
			for (int att2 = 0; att2 < m_NumAttributes; att2++) {
				if (att2 == m_ClassIndex || att2 == att1)
					continue;
				aveMutualInfo[att1] += m_mutualInformation[att1][att2];
			}
			if (numPredictors > 1)
				aveMutualInfo[att1] /= (numPredictors - 1);
		}

		Attr_Weight = new double[m_NumAttributes];
		for (int att = 0; att < m_NumAttributes; att++) {
			if (att == m_ClassIndex)
				continue;
			Attr_Weight[att] = 1 / (1 + Math.exp(-(m_mutualInformation[att][m_ClassIndex] - aveMutualInfo[att])));
		}
	}

	/**
	 * Calculates the class membership probabilities for the given test instance.
	 *
	 * <p>The score of class c is P(c) * prod_j P(a_j | c)^w_j with Laplace-style smoothing. It is
	 * accumulated in log space so long products cannot underflow to zero for every class.
	 * Missing attribute values are skipped.
	 *
	 * @param instance the instance to be classified
	 * @return predicted class probability distribution (sums to 1)
	 * @exception Exception if there is a problem generating the prediction
	 */
	public double[] distributionForInstance(Instance instance) throws Exception {
		double[] logProbs = new double[m_NumClasses];

		for (int c = 0; c < m_NumClasses; c++) {
			// The prior's denominator is shared by all classes, so it cancels on normalisation.
			logProbs[c] = Math.log((m_CountClass[c] + 1.0 / m_NumClasses) / (m_NumInstances + 1.0));

			for (int j = 0; j < m_NumAttributes; j++) {
				if (j == m_ClassIndex || instance.isMissing(j))
					continue;
				int value = (int) instance.value(j);
				double conditional = (m_CountClassAi[c][j][value] + 1.0 / m_NumAiValue[j]) / (m_CountClass[c] + 1.0);
				logProbs[c] += Attr_Weight[j] * Math.log(conditional);
			}
		}
		// Exponentiates relative to the maximum and normalises.
		return Utils.logs2probs(logProbs);
	}

	/**
	 * Mutual information (in bits) between two attributes, estimated from the unweighted counts.
	 * Value pairs that never co-occur contribute 0 (0 * log 0 := 0).
	 */
	private double mutualInfo(int att1, int att2) {
		double mutualInfo = 0;
		int start1 = m_StartAttIndex[att1];
		int start2 = m_StartAttIndex[att2];

		for (int i = 0; i < m_NumAttValues[att1]; i++) {
			for (int j = 0; j < m_NumAttValues[att2]; j++) {
				double joint = m_AttAttCounts[start1 + i][start2 + j];
				if (joint == 0)
					continue; // also covers empty data (m_NumInstances == 0)
				double pJoint = joint / m_NumInstances;          // P(Xi, Xj)
				double p1 = m_AttCounts[start1 + i] / m_NumInstances; // P(Xi)
				double p2 = m_AttCounts[start2 + j] / m_NumInstances; // P(Xj)
				mutualInfo += pJoint * log2(pJoint, p1 * p2);
			}
		}
		return mutualInfo;
	}

	/**
	 * Returns log2(x / y), or 0 when either argument is not positive.
	 *
	 * <p>Only genuinely zero terms are dropped. A fixed cut-off such as 1e-6 would discard
	 * terms whose marginal product is tiny but whose joint probability is not.
	 */
	private double log2(double x, double y) {
		if (x <= 0 || y <= 0)
			return 0.0;
		return Math.log(x / y) / Math.log(2);
	}

	/**
	 * Computes, for every ordered attribute pair (Xi, Xj) including i == j and the class
	 * attribute, the sum over y, xi, xj of
	 * P(y, xi, xj) * log2( N(y, xi, xj) N(xj) / (N(xi, xj) N(y, xj)) ).
	 * This is the count-based estimate of the conditional mutual information I(Y; Xi | Xj).
	 *
	 * @return matrix indexed [i][j]
	 */
	public double[][] getConditionMutualInfor_yxx() {
		double[][] cmiYxx = new double[m_NumAttributes][m_NumAttributes];

		for (int atti = 0; atti < m_NumAttributes; atti++) {         // attribute Xi
			int startI = m_StartAttIndex[atti];                     // first flattened index of Xi
			for (int attj = 0; attj < m_NumAttributes; attj++) {     // attribute Xj
				int startJ = m_StartAttIndex[attj];                 // first flattened index of Xj
				double cmi = 0.0;
				for (int vi = 0; vi < m_NumAttValues[atti]; vi++) {
					for (int vj = 0; vj < m_NumAttValues[attj]; vj++) {
						int xi = startI + vi;
						int xj = startJ + vj;
						for (int y = 0; y < m_NumClasses; y++) {
							double nYXiXj = m_ClassAttAttCounts[y][xi][xj];
							if (nYXiXj == 0)
								continue; // only non-zero Count(y, xi, xj) contributes
							cmi += (nYXiXj / m_NumInstances)
									* log2(nYXiXj * m_AttCounts[xj], m_AttAttCounts[xi][xj] * m_ClassAttCounts[y][xj]);
						}
					}
				}
				cmiYxx[atti][attj] = cmi;
			}
		}

		return cmiYxx;
	}

	/**
	 * Command-line entry point; delegates to Weka's standard classifier runner.
	 *
	 * @param argv Weka command-line options
	 */
	public static void main(String[] argv) {
		runClassifier(new AIWNB_E(), argv);
	}

}
